import numpy as np


def synthetic_pam_attn(topic_matrix, twords=1, lam=30.0, n_docs=100):
    """Sample synthetic documents and target words from a two-level topic model.

    For every document:

    * a super-topic mixture is drawn from a symmetric Dirichlet with
      concentration ``1 / K_s`` (``K_s = K // 2``);
    * a fresh ``(K_s, K)`` super-topic -> topic matrix is drawn with
      ``correlated_supertopic_topic``;
    * the two are multiplied into a per-document distribution over the K
      topics;
    * each token draws a topic from that distribution and then a word from
      the matching row of ``topic_matrix``;
    * ``twords`` target words are drawn the same way.

    Parameters
    ----------
    topic_matrix : numpy.ndarray
        ``(K, V)`` array. Row ``k`` is passed as ``p`` to
        ``np.random.choice``, so each row must be a probability vector over
        the V word ids. K must be at least 2 (``K // 2`` is used as a
        divisor).
    twords : int
        Number of target words per document.
    lam : int or float
        Length of every document. All documents have the same length; a
        float is truncated to ``int`` (the default ``30.0`` gives 30 tokens).
    n_docs : int
        Number of documents to generate.

    Returns
    -------
    tuple
        ``(docs, Y)``. ``docs`` is a list of ``n_docs`` lists of word ids.
        ``Y`` is ``np.array`` of the per-document target lists, of shape
        ``(n_docs, twords)``.
    """
    K, V = topic_matrix.shape
    K_s = int(K / 2)

    # range() and np.random.choice(size=...) reject float sizes, so the float
    # default lam=30.0 used to raise TypeError. Convert once up front.
    doc_length = int(lam)

    ret_docs, Y = [], []

    for _ in range(n_docs):
        super_topics = np.random.dirichlet((1 / K_s) * np.ones(K_s))
        supertopic_topics = correlated_supertopic_topic(K, K_s, 30)
        topic_dist = super_topics.dot(supertopic_topics)

        doc_topics = np.random.choice(K, size=doc_length, replace=True, p=topic_dist)
        # One word per token, drawn from that token's topic (same draw order
        # as a per-index loop, so seeded output is unchanged).
        document = [
            np.random.choice(V, size=1, p=topic_matrix[k, :]).item()
            for k in doc_topics
        ]
        ret_docs.append(document)

        # 2 * twords topics are drawn but only the even-indexed ones are used.
        # NOTE(review): presumably this keeps the random stream aligned with
        # synthetic_pam_two_targets; kept so seeded results do not change.
        extra_topics = np.random.choice(K, size=twords * 2, replace=True,
                                        p=topic_dist)
        target_words = [
            np.random.choice(V, size=1, p=topic_matrix[extra_topics[2 * t], :]).item()
            for t in range(twords)
        ]

        Y.append(target_words)

    return (ret_docs, np.array(Y))

def synthetic_pam_two_targets(topic_matrix, twords=1, lam=30.0, n_docs=100):
    """Sample synthetic documents whose targets are encoded pairs of words.

    For every document:

    * a super-topic mixture is drawn from a symmetric Dirichlet with
      concentration ``1 / K_s`` (``K_s = K // 2``);
    * a fresh ``(K_s, K)`` super-topic -> topic matrix is drawn with
      ``correlated_supertopic_topic``;
    * the two are multiplied into a per-document distribution over the K
      topics;
    * each token draws a topic from that distribution and then a word from
      the matching row of ``topic_matrix``.

    Each of the ``twords`` targets is an ordered pair of words. Each word is
    drawn from its own topic, and the pair is encoded as the single id
    ``first + second * V``.

    Parameters
    ----------
    topic_matrix : numpy.ndarray
        ``(K, V)`` array. Row ``k`` is passed as ``p`` to
        ``np.random.choice``, so each row must be a probability vector over
        the V word ids. K must be at least 2 (``K // 2`` is used as a
        divisor).
    twords : int
        Number of (pair-encoded) targets per document.
    lam : int or float
        Length of every document. All documents have the same length; a
        float is truncated to ``int`` (the default ``30.0`` gives 30 tokens).
    n_docs : int
        Number of documents to generate.

    Returns
    -------
    tuple
        ``(docs, Y)``. ``docs`` is a list of ``n_docs`` lists of word ids.
        ``Y`` is ``np.array`` of shape ``(n_docs, twords)``, with values in
        ``[0, V * V)``.
    """
    K, V = topic_matrix.shape
    K_s = int(K / 2)

    # range() and np.random.choice(size=...) reject float sizes, so the float
    # default lam=30.0 used to raise TypeError. Convert once up front.
    doc_length = int(lam)

    ret_docs, Y = [], []

    for _ in range(n_docs):
        super_topics = np.random.dirichlet((1 / K_s) * np.ones(K_s))
        supertopic_topics = correlated_supertopic_topic(K, K_s, 30)
        topic_dist = super_topics.dot(supertopic_topics)

        doc_topics = np.random.choice(K, size=doc_length, replace=True, p=topic_dist)
        document = [
            np.random.choice(V, size=1, p=topic_matrix[k, :]).item()
            for k in doc_topics
        ]
        ret_docs.append(document)

        # Two topics per target: index 2t for the first word, 2t + 1 for the
        # second.
        extra_topics = np.random.choice(K, size=twords * 2, replace=True,
                                        p=topic_dist)
        target_words = []
        for t in range(twords):
            first_target = np.random.choice(V, size=1, p=topic_matrix[extra_topics[2 * t], :]).item()
            second_target = np.random.choice(V, size=1, p=topic_matrix[extra_topics[2 * t + 1], :]).item()
            # Encode the ordered pair as a single id in [0, V * V).
            target_words.append(first_target + second_target * V)

        Y.append(target_words)

    return (ret_docs, np.array(Y))

def correlated_supertopic_topic(K=20, K_s=10, dirich_param=30):
    """Build a ``(K_s, K)`` matrix mapping each super-topic onto a pair of topics.

    With ``step = K // K_s``, super-topic ``s`` puts all of its mass on
    topics ``s * step`` and ``s * step + 1``. The split within the pair comes
    from one two-component Dirichlet draw with symmetric concentration
    ``dirich_param``. That same split is reused for every super-topic, so
    every row sums to 1.

    The matrix always has exactly ``K_s`` rows. Previously the row count
    followed ``range(0, K, step)``. That raised ``IndexError`` for odd K
    (e.g. K=5, K_s=2) and returned extra rows for other K/K_s combinations.

    Parameters
    ----------
    K : int
        Number of topics (columns).
    K_s : int
        Number of super-topics (rows). Requires ``1 <= K_s`` and
        ``2 * K_s <= K`` so that every pair fits inside the K topics.
    dirich_param : float
        Symmetric Dirichlet concentration for the within-pair split.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(K_s, K)``.

    Raises
    ------
    ValueError
        If ``K_s < 1`` or ``2 * K_s > K``.
    """
    if K_s < 1 or 2 * K_s > K:
        raise ValueError(
            f"need 1 <= K_s and 2 * K_s <= K, got K={K}, K_s={K_s}"
        )

    dirich_sample = np.random.dirichlet([dirich_param, dirich_param])
    step = K // K_s

    supertopic_topics = np.zeros((K_s, K))
    rows = np.arange(K_s)
    first_cols = rows * step
    supertopic_topics[rows, first_cols] = dirich_sample[0]
    supertopic_topics[rows, first_cols + 1] = dirich_sample[1]
    return supertopic_topics

def sample_reuse(train_doc_list, twords=1):
    """Split each document into input tokens and held-out target tokens.

    For each document, ``twords`` distinct positions are chosen uniformly at
    random (without replacement). The tokens at those positions become the
    targets. The remaining tokens, in their original order, become the
    input.

    Parameters
    ----------
    train_doc_list : iterable of sequences
        Documents to split. Expected shape is ``(n_docs, lam + twords)``;
        this is not checked here.
    twords : int
        Number of tokens to hold out per document. It must not exceed the
        document length, otherwise ``np.random.choice`` raises.

    Returns
    -------
    tuple of (list of numpy.ndarray, list of numpy.ndarray)
        ``(inputs, targets)``, one array of each per document. Targets appear
        in the order they were drawn, not in document order.
    """
    inputs, targets = [], []

    for raw_doc in train_doc_list:
        held_out = np.random.choice(len(raw_doc), size=twords, replace=False)
        tokens = np.array(raw_doc)
        targets.append(tokens[held_out])
        # Drop the held-out positions along the token axis; order is kept.
        inputs.append(np.delete(tokens, held_out, axis=0))

    return inputs, targets